{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9aea2547",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "from sklearn import datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c085a32e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read each split from its own directory tree. No encoding is passed, so the\n",
    "# documents stay as raw bytes and have to be decoded before text handling.\n",
    "twenty_train = datasets.load_files(\n",
    "    container_path=\"data/20news-bydate/20news-bydate-train\"\n",
    ")\n",
    "twenty_test = datasets.load_files(\n",
    "    container_path=\"data/20news-bydate/20news-bydate-test\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a881ec2e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(20, 11314, 11314, 7532)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check of the load, in order: number of classes, training documents,\n",
    "# training file paths, test documents.\n",
    "tuple(\n",
    "    map(\n",
    "        len,\n",
    "        (\n",
    "            twenty_train.target_names,\n",
    "            twenty_train.data,\n",
    "            twenty_train.filenames,\n",
    "            twenty_test.data,\n",
    "        ),\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "231dab52",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "From: cubbie@garnet.berkeley.edu (                               )\n",
      "Subject: Re: Cubs behind Marlins? How?\n",
      "Article-I.D.: agate.1pt592$f9a\n"
     ]
    }
   ],
   "source": [
    "# Show the first three lines of the first training document (decoded from bytes).\n",
    "print(\n",
    "    *twenty_train.data[0].decode().strip().split(\"\\n\")[:3],\n",
    "    sep=\"\\n\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "2fccca67",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "rec.sport.baseball\n"
     ]
    }
   ],
   "source": [
    "# target[0] is the integer label of the first training document;\n",
    "# indexing target_names with it gives the label's name.\n",
    "print(twenty_train.target_names[twenty_train.target[0]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "73dfc18d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 9,  4, 11,  4,  0,  4,  5,  5, 13, 12])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Peek at the first ten labels; each value is an index into twenty_train.target_names.\n",
    "twenty_train.target[:10]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a530b82a",
   "metadata": {},
   "source": [
    "# 计算词频"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "90857aed",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(11314, 129782)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.feature_extraction.text import CountVectorizer\n",
    "\n",
    "# Bag-of-words counts: English stop words are dropped, and bytes that cannot\n",
    "# be decoded are ignored instead of raising.\n",
    "count_vect = CountVectorizer(\n",
    "    decode_error='ignore',\n",
    "    stop_words=\"english\",\n",
    ")\n",
    "\n",
    "# Learn the vocabulary and build the document-term matrix in a single call.\n",
    "X_train_counts = count_vect.fit_transform(raw_documents=twenty_train.data)\n",
    "X_train_counts.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "49598f92",
   "metadata": {},
   "source": [
    "# 使用TF-IDF进行特征提取"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "499d9630",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(11314, 129782)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.feature_extraction.text import TfidfTransformer\n",
    "\n",
    "# use_idf=False switches off the inverse-document-frequency weighting, so only\n",
    "# term frequencies are produced here.\n",
    "tf_transformer = TfidfTransformer(use_idf=False)\n",
    "X_train_tf = tf_transformer.fit_transform(X_train_counts)\n",
    "X_train_tf.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "e4d72fe5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(11314, 129782)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# TF-IDF with the default settings; fitting and transforming are spelled out\n",
    "# as two explicit steps on the same count matrix.\n",
    "tfidf_transformer = TfidfTransformer().fit(X_train_counts)\n",
    "X_train_tfidf = tfidf_transformer.transform(X_train_counts)\n",
    "X_train_tfidf.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c8e8640",
   "metadata": {},
   "source": [
    "# 分类器训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "2c24693b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.naive_bayes import MultinomialNB\n",
    "\n",
    "# Train on the TF-IDF features; the result of fit() is bound to clf and is\n",
    "# what the later predict() call uses.\n",
    "clf = MultinomialNB().fit(\n",
    "    X=X_train_tfidf,\n",
    "    y=twenty_train.target,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "52fdbb13",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n",
      "1\n"
     ]
    }
   ],
   "source": [
    "docs_new = ['God is love','OpenGL on the GPU is fast']\n",
    "X_new_counts = count_vect.transform(docs_new)\n",
    "X_new_tfidf = tfidf_transformer.transform(X_new_counts)\n",
    "\n",
    "predicted = clf.predict(X_new_tfidf)\n",
    "\n",
    "for doc,category in zip(docs_new,predicted):\n",
    "    print(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "160fa0a8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
